{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings \n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 50, 768])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class Atten(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.q = torch.nn.Linear(768, 768, bias=True)\n",
    "        self.k = torch.nn.Linear(768, 768, bias=False)\n",
    "        self.v = torch.nn.Linear(768, 768, bias=True)\n",
    "        self.out = torch.nn.Linear(768, 768, bias=True)\n",
    "\n",
    "    def forward(self, x, mask):\n",
    "        b, lens, _ = x.size()\n",
    "\n",
    "        q = self.q(x) * 0.125\n",
    "        k = self.k(x)\n",
    "        v = self.v(x)\n",
    "\n",
    "        #[2, 50, 768] -> [2, 50, 12, 64] -> [2, 12, 50, 64] -> [24, 50, 64]\n",
    "        q = q.reshape(b, lens, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, lens, 64)\n",
    "        k = k.reshape(b, lens, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, lens, 64)\n",
    "        v = v.reshape(b, lens, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, lens, 64)\n",
    "\n",
    "        #[24, 50, 64] * [24, 64, 50] -> [24, 50, 50] -> [2, 12, 50, 50]\n",
    "        atten = q.bmm(k.transpose(1, 2)).reshape(b, 12, lens, lens) + mask\n",
    "\n",
    "        #[2, 12, 50, 50] -> [24, 50, 50] * [24, 50, 64] -> [24, 50, 64]\n",
    "        atten = atten.reshape(b * 12, lens, lens).softmax(dim=-1).bmm(v)\n",
    "\n",
    "        #[24, 50, 64] -> [2, 12, 50, 64] -> [2, 50, 12, 64] -> [2, 50, 768]\n",
    "        atten = atten.reshape(b, 12, lens,\n",
    "                              64).transpose(1, 2).reshape(b, lens, 768)\n",
    "\n",
    "        return self.out(atten)\n",
    "\n",
    "\n",
    "Atten()(torch.randn(2, 50, 768), torch.ones(2, 1, 50, 50).long()).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 50, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CrossAtten(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.q = torch.nn.Linear(768, 768, bias=True)\n",
    "        self.k = torch.nn.Linear(768, 768, bias=False)\n",
    "        self.v = torch.nn.Linear(768, 768, bias=True)\n",
    "\n",
    "        self.out = torch.nn.Linear(768, 768, bias=True)\n",
    "\n",
    "    def forward(self, x, kv):\n",
    "        b, lens, _ = x.size()\n",
    "\n",
    "        q = self.q(x) * 0.125\n",
    "        k = self.k(kv)\n",
    "        v = self.v(kv)\n",
    "\n",
    "        q = q.reshape(b, lens, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, lens, 64)\n",
    "        k = k.reshape(b, 1500, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, 1500, 64)\n",
    "        v = v.reshape(b, 1500, 12, 64).transpose(1,\n",
    "                                                 2).reshape(b * 12, 1500, 64)\n",
    "\n",
    "        atten = q.bmm(k.transpose(1, 2)).softmax(dim=-1).bmm(v)\n",
    "        atten = atten.reshape(b, 12, lens,\n",
    "                              64).transpose(1, 2).reshape(b, lens, 768)\n",
    "\n",
    "        return self.out(atten)\n",
    "\n",
    "\n",
    "CrossAtten()(torch.randn(2, 50, 768), torch.randn(2, 1500, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 50, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Layer(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.norm1 = torch.nn.LayerNorm(768)\n",
    "        self.atten = Atten()\n",
    "\n",
    "        self.norm2 = torch.nn.LayerNorm(768)\n",
    "        self.cross_atten = CrossAtten()\n",
    "\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.LayerNorm(768),\n",
    "            torch.nn.Linear(768, 3072),\n",
    "            torch.torch.nn.GELU(),\n",
    "            torch.nn.Linear(3072, 768),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, mask, kv):\n",
    "        x = self.atten(self.norm1(x), mask=mask) + x\n",
    "        x = self.cross_atten(self.norm2(x), kv=kv) + x\n",
    "\n",
    "        return self.s(x) + x\n",
    "\n",
    "\n",
    "Layer()(torch.randn(2, 50, 768), torch.ones(2, 1, 50, 50).long(),\n",
    "        torch.randn(2, 1500, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0., -inf, -inf, -inf, -inf],\n",
       "          [0., 0., -inf, -inf, -inf],\n",
       "          [0., 0., 0., -inf, -inf],\n",
       "          [0., 0., 0., 0., -inf],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., -inf, -inf, -inf, -inf],\n",
       "          [0., 0., -inf, -inf, -inf],\n",
       "          [0., 0., 0., -inf, -inf],\n",
       "          [0., 0., 0., 0., -inf],\n",
       "          [0., 0., 0., 0., 0.]]]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_mask(b, lens):\n",
    "    mask = torch.full((lens, lens), -float('inf'))\n",
    "\n",
    "    t = torch.arange(lens)\n",
    "    t = t < (t + 1).reshape(lens, 1)\n",
    "    mask.masked_fill_(t, 0.0)\n",
    "\n",
    "    return mask.reshape(1, 1, lens, lens).repeat(b, 1, 1, 1)\n",
    "\n",
    "\n",
    "get_mask(2, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 50, 768])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Decoder(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.embed = torch.nn.Embedding(51865, 768, 50257)\n",
    "        self.embed_pos = torch.nn.Embedding(448, 768)\n",
    "\n",
    "        self.layer = torch.nn.ModuleList([Layer() for _ in range(12)])\n",
    "        self.norm = torch.nn.LayerNorm(768)\n",
    "\n",
    "    def forward(self, x, kv):\n",
    "        mask = get_mask(*x.shape).to(x.device)\n",
    "        \n",
    "        x = self.embed(x) + self.embed_pos.weight[:x.shape[1]]\n",
    "\n",
    "        for i in self.layer:\n",
    "            x = i(x, mask=mask, kv=kv)\n",
    "\n",
    "        return self.norm(x)\n",
    "\n",
    "\n",
    "Decoder()(torch.ones(2, 50).long(), kv=torch.randn(2, 1500, 768)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_decoder(pretrained):\n",
    "    decoder = Decoder()\n",
    "\n",
    "    decoder.embed.load_state_dict(pretrained.embed_tokens.state_dict())\n",
    "    decoder.embed_pos.load_state_dict(pretrained.embed_positions.state_dict())\n",
    "    decoder.norm.load_state_dict(pretrained.layer_norm.state_dict())\n",
    "\n",
    "    for i in range(12):\n",
    "\n",
    "        decoder.layer[i].norm1.load_state_dict(\n",
    "            pretrained.layers[i].self_attn_layer_norm.state_dict())\n",
    "\n",
    "        decoder.layer[i].atten.q.load_state_dict(\n",
    "            pretrained.layers[i].self_attn.q_proj.state_dict())\n",
    "        decoder.layer[i].atten.k.load_state_dict(\n",
    "            pretrained.layers[i].self_attn.k_proj.state_dict())\n",
    "        decoder.layer[i].atten.v.load_state_dict(\n",
    "            pretrained.layers[i].self_attn.v_proj.state_dict())\n",
    "        decoder.layer[i].atten.out.load_state_dict(\n",
    "            pretrained.layers[i].self_attn.out_proj.state_dict())\n",
    "\n",
    "        decoder.layer[i].norm2.load_state_dict(\n",
    "            pretrained.layers[i].encoder_attn_layer_norm.state_dict())\n",
    "\n",
    "        decoder.layer[i].cross_atten.q.load_state_dict(\n",
    "            pretrained.layers[i].encoder_attn.q_proj.state_dict())\n",
    "        decoder.layer[i].cross_atten.k.load_state_dict(\n",
    "            pretrained.layers[i].encoder_attn.k_proj.state_dict())\n",
    "        decoder.layer[i].cross_atten.v.load_state_dict(\n",
    "            pretrained.layers[i].encoder_attn.v_proj.state_dict())\n",
    "        decoder.layer[i].cross_atten.out.load_state_dict(\n",
    "            pretrained.layers[i].encoder_attn.out_proj.state_dict())\n",
    "\n",
    "        decoder.layer[i].s[0].load_state_dict(\n",
    "            pretrained.layers[i].final_layer_norm.state_dict())\n",
    "\n",
    "        decoder.layer[i].s[1].load_state_dict(\n",
    "            pretrained.layers[i].fc1.state_dict())\n",
    "        decoder.layer[i].s[3].load_state_dict(\n",
    "            pretrained.layers[i].fc2.state_dict())\n",
    "\n",
    "    return decoder\n",
    "\n",
    "\n",
    "from transformers import WhisperForConditionalGeneration\n",
    "\n",
    "pretrained = WhisperForConditionalGeneration.from_pretrained(\n",
    "    'models/whisper-small').model.decoder\n",
    "decoder = load_decoder(pretrained)\n",
    "\n",
    "x = torch.ones(2, 50).long()\n",
    "kv = torch.randn(2, 1500, 768)\n",
    "\n",
    "out1 = decoder(x, kv)\n",
    "out2 = pretrained(input_ids=x, attention_mask=None,\n",
    "                  encoder_hidden_states=kv).last_hidden_state\n",
    "\n",
    "(out1 == out2).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simple",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
